package drds.data_propagate.driver;

import drds.common.$;
import drds.data_propagate.driver.packets.HeaderPacket;
import drds.data_propagate.driver.packets.client.ClientAuthenticationPacket;
import drds.data_propagate.driver.packets.client.command_packet.AuthSwitchResponsePacket;
import drds.data_propagate.driver.packets.client.command_packet.QuitPacket;
import drds.data_propagate.driver.packets.server.Auth323Packet;
import drds.data_propagate.driver.packets.server.ErrorPacket;
import drds.data_propagate.driver.packets.server.HandshakeInitializationPacket;
import drds.data_propagate.driver.packets.server.command_packet.AuthSwitchRequestMoreData;
import drds.data_propagate.driver.packets.server.command_packet.AuthSwitchRequestPacket;
import drds.data_propagate.driver.socket.SocketChannel;
import drds.data_propagate.driver.socket.SocketChannelPool;
import drds.data_propagate.driver.utils.MSC;
import drds.data_propagate.driver.utils.MySQLPasswordEncrypter;
import drds.data_propagate.driver.utils.PacketManager;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.DigestException;
import java.security.NoSuchAlgorithmException;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * 基于mysql socket协议的链接实现
 */
@Slf4j
public class Connector {
    //
    public static final int timeout = 5 * 1000; // 5s
    @Setter
    @Getter
    private InetSocketAddress inetSocketAddress;
    @Setter
    @Getter
    private String username;
    @Setter
    @Getter
    private String password;
    //
    @Setter
    @Getter
    private String defaultSchema;
    @Setter
    @Getter
    private byte charsetNumber = 33;

    //
    @Setter
    @Getter
    private int receiveBufferSize = 16 * 1024;

    @Setter
    @Getter
    private SocketChannel socketChannel;
    @Setter
    @Getter
    private volatile boolean dumping = false;
    @Setter
    @Getter
    private long threadId = -1;//数据库启动的线程id
    @Setter
    @Getter
    private AtomicBoolean connected = new AtomicBoolean(false);

    public Connector() {
    }

    public Connector(InetSocketAddress inetSocketAddress, String username, String password) {
        String host = inetSocketAddress.getHostString();
        int port = inetSocketAddress.getPort();
        this.inetSocketAddress = new InetSocketAddress(host, port);
        this.username = username;
        this.password = password;
    }

    public Connector(InetSocketAddress inetSocketAddress, String username, String password, byte charsetNumber,
                     String defaultSchema) {
        this(inetSocketAddress, username, password);
        this.defaultSchema = defaultSchema;
        this.charsetNumber = charsetNumber;
    }

    public void reconnect() throws IOException {
        disconnect();
        connect();
    }

    /**
     * 如果出现读取包数据异常,请检查主从复制账号权限是否OK,以及该ip是否允许被访问,eg:denied to user 'rc'@'192.168.0.105'
     */
    public void connect() throws IOException {
        if (connected.compareAndSet(false, true)) {
            try {
                socketChannel = SocketChannelPool.open(inetSocketAddress);
                log.info("connect to {}...", inetSocketAddress);
                String id = UUID.randomUUID().toString().replaceAll("-", "").toLowerCase();
                negotiate(socketChannel, id);
                log.info("服务器交互成功,没有爆出异常(id:" + id + ")");
            } catch (Exception e) {
                if (e.getMessage().contains("recv failed ")) {
                    log.error("客户端和服务端建立tcp的短连接,每次客户端发送一次请求, 服务端响应后关闭与客户端的连接. 如果客户端在服务端关闭连接后,没有释放连接,继续试图发送请求和接收响应. 这个时候就会出错. ");
                }
                log.error($.printStackTraceToString(e));
                disconnect();
                throw new IOException("connect " + this.inetSocketAddress + " failure", e);
            }
        } else {
            log.error("the socketChannel can't be connected twice.");
        }
        log.info("连接成功");
    }

    public boolean isConnected() {
        return this.socketChannel != null && this.socketChannel.isConnected();
    }

    public void disconnect() throws IOException {
        if (!connected.compareAndSet(true, false)) {
            log.info("the socketChannel {} is not connected", this.inetSocketAddress);
        } else {
            log.info("disConnect MysqlConnection to {}...", inetSocketAddress);
            if (socketChannel != null) {
                try {
                    socketChannel.close();
                } catch (Exception e) {
                    throw new IOException("disconnect " + this.inetSocketAddress + " failure", e);
                }
            }
            // 执行一次quit
            if (dumping && threadId >= 0) {
                Connector connector = null;
                try {
                    connector = this.fork();
                    connector.connect();
                    UpdateExecutor updateExecutor = new UpdateExecutor(connector);
                    //不会出现死循环,因为Connector只有在dump情况下才需要其他线程来关闭连接
                    updateExecutor.update("KILL CONNECTION " + threadId);
                } catch (Exception e) {
                    // 忽略具体异常
                    log.info("KILL DUMP " + threadId + " failure", e);
                } finally {
                    if (connector != null) {
                        connector.disconnect();
                    }
                }

                dumping = false;
            }
        }
    }

    public Connector fork() {
        Connector connector = new Connector();
        connector.setCharsetNumber(getCharsetNumber());
        connector.setDefaultSchema(getDefaultSchema());
        connector.setInetSocketAddress(getInetSocketAddress());
        connector.setPassword(password);
        connector.setUsername(getUsername());
        connector.setReceiveBufferSize(getReceiveBufferSize());

        return connector;
    }

    public void quit() throws IOException {
        QuitPacket quitPacket = new QuitPacket();
        byte[] packetBodyBytes = quitPacket.encode();

        HeaderPacket headerPacket = new HeaderPacket();
        headerPacket.setPacketBodyLength(packetBodyBytes.length);
        headerPacket.setPacketSequenceNumber((byte) 0x00);
        PacketManager.writePackets(socketChannel, headerPacket.encode(), packetBodyBytes);
    }

    private void negotiate(SocketChannel socketChannel, String id) throws IOException {

        log.info("开始进行服务器校验(id:" + id + "),首先获取连接服务器后,服务器主动发送过来的信息");
        //该url保留 by czh
        // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol
        HeaderPacket headerPacket = PacketManager.readHeader(socketChannel, 4, timeout);
        byte[] packetBodyBytes = PacketManager.readBytes(socketChannel, headerPacket.getPacketBodyLength(), timeout);
        if (packetBodyBytes[0] < 0) {// check field_count
            if (packetBodyBytes[0] == -1) {
                ErrorPacket errorPacket = new ErrorPacket();
                errorPacket.decode(packetBodyBytes);
                throw new IOException("handshake exception:\n" + errorPacket.toString());
            } else if (packetBodyBytes[0] == -2) {
                throw new IOException("Unexpected EOF packet at handshake phase.");
            } else {
                throw new IOException("unpexpected packet with field_count=" + packetBodyBytes[0]);
            }
        }
        //
        log.info("开始进行服务器握手包解析(id:" + id + ")");
        HandshakeInitializationPacket handshakeInitializationPacket = new HandshakeInitializationPacket();
        handshakeInitializationPacket.decode(packetBodyBytes);
        if (handshakeInitializationPacket.protocolVersion != MSC.DEFAULT_PROTOCOL_VERSION) {
            // HandshakeV9
            auth323(socketChannel, (byte) (headerPacket.getPacketSequenceNumber() + 1), handshakeInitializationPacket.seed);
            return;
        }

        threadId = handshakeInitializationPacket.threadId; // 记录一下connection
        log.info("开始进行服务器认证包数据的组装(id:" + id + ")");
        ClientAuthenticationPacket clientAuthenticationPacket = new ClientAuthenticationPacket();
        clientAuthenticationPacket.setCharsetNumber(charsetNumber);

        clientAuthenticationPacket.setUsername(username);
        clientAuthenticationPacket.setPassword(password);
        clientAuthenticationPacket.setServerCapabilities(handshakeInitializationPacket.serverCapabilities);
        clientAuthenticationPacket.setDatabaseName(defaultSchema);
        clientAuthenticationPacket.setScrumbleBuff(joinAndCreateScrumbleBuff(handshakeInitializationPacket));
        clientAuthenticationPacket.setAuthPluginName("mysql_native_password".getBytes());

        byte[] clientAuthenticationPacketBodyBytes = clientAuthenticationPacket.encode();
        HeaderPacket clientAuthenticationPacketHeaderPacket = new HeaderPacket();
        clientAuthenticationPacketHeaderPacket.setPacketBodyLength(clientAuthenticationPacketBodyBytes.length);
        clientAuthenticationPacketHeaderPacket.setPacketSequenceNumber((byte) (headerPacket.getPacketSequenceNumber() + 1));
        log.info("开始发送服务器认证包(id:" + id + ")");
        PacketManager.writePackets(socketChannel, clientAuthenticationPacketHeaderPacket.encode(), clientAuthenticationPacketBodyBytes);
        log.info("发送服务器认证包成功(id:" + id + ")");

        // check auth result
        headerPacket = null;
        packetBodyBytes = null;
        headerPacket = PacketManager.readHeader(socketChannel, 4);
        packetBodyBytes = PacketManager.readBytes(socketChannel, headerPacket.getPacketBodyLength(), timeout);
        assert packetBodyBytes != null;
        if (packetBodyBytes != null) {
            log.info("接受服务器认证包响应成功(id:" + id + "),头信息如下:" + headerPacket);
        }
        log.info("接受服务器认证包响应成功(id:" + id + ")");
        //
        byte marker = packetBodyBytes[0];
        if (marker == -2 || marker == 1) {
            byte[] authData = null;
            String pluginName = null;
            if (marker == -2) {
                AuthSwitchRequestPacket authSwitchRequestPacket = new AuthSwitchRequestPacket();
                authSwitchRequestPacket.decode(packetBodyBytes);
                authData = authSwitchRequestPacket.authData;
                pluginName = authSwitchRequestPacket.pluginName;
            } else if (marker == 1) {
                AuthSwitchRequestMoreData authSwitchRequestMoreData = new AuthSwitchRequestMoreData();
                authSwitchRequestMoreData.decode(packetBodyBytes);
                authData = authSwitchRequestMoreData.authData;
            } else {
                throw new UnsupportedOperationException();
            }
            boolean isSha2Password = false;
            byte[] encryptedPassword = null;
            if (pluginName != null && "mysql_native_password".equals(pluginName)) {
                try {
                    encryptedPassword = MySQLPasswordEncrypter.scramble411(getPassword().getBytes(), authData);
                } catch (NoSuchAlgorithmException e) {
                    throw new RuntimeException("can't encrypt password that will be sent to MySQL server.", e);
                }
            } else if (pluginName != null && "caching_sha2_password".equals(pluginName)) {
                isSha2Password = true;
                try {
                    encryptedPassword = MySQLPasswordEncrypter.scrambleCachingSha2(getPassword().getBytes(), authData);
                } catch (DigestException e) {
                    throw new RuntimeException("can't encrypt password that will be sent to MySQL server.", e);
                }
            }
            assert encryptedPassword != null;
            //
            AuthSwitchResponsePacket authSwitchResponsePacket = new AuthSwitchResponsePacket();
            authSwitchResponsePacket.authData = encryptedPassword;
            byte[] authSwitchResponsePacketBodyBytes = authSwitchResponsePacket.encode();
            //
            clientAuthenticationPacketHeaderPacket = new HeaderPacket();
            clientAuthenticationPacketHeaderPacket.setPacketBodyLength(authSwitchResponsePacketBodyBytes.length);
            clientAuthenticationPacketHeaderPacket.setPacketSequenceNumber((byte) (headerPacket.getPacketSequenceNumber() + 1));
            PacketManager.writePackets(socketChannel, clientAuthenticationPacketHeaderPacket.encode(), authSwitchResponsePacketBodyBytes);
            log.info("auth switch response packet is sent out.");
            //
            headerPacket = null;
            packetBodyBytes = null;

            headerPacket = PacketManager.readHeader(socketChannel, 4);
            packetBodyBytes = PacketManager.readBytes(socketChannel, headerPacket.getPacketBodyLength(), timeout);
            assert packetBodyBytes != null;
            if (isSha2Password) {
                if (packetBodyBytes[0] == 0x01 && packetBodyBytes[1] == 0x04) {
                    // password auth failed
                    throw new IOException("caching_sha2_password Auth failed");
                }
                //需要再进行读取`
                headerPacket = null;
                packetBodyBytes = null;
                //
                headerPacket = PacketManager.readHeader(socketChannel, 4);
                packetBodyBytes = PacketManager.readBytes(socketChannel, headerPacket.getPacketBodyLength(), timeout);
            }
        }
        //error
        if (packetBodyBytes[0] < 0) {
            if (packetBodyBytes[0] == -1) {
                ErrorPacket errorPacket = new ErrorPacket();
                errorPacket.decode(packetBodyBytes);
                throw new IOException("Error When doing Client Authentication:" + errorPacket.toString());
            } else {
                throw new IOException("unpexpected packet with field_count=" + packetBodyBytes[0]);
            }
        }
    }

    private void auth323(SocketChannel channel, byte packetSequenceNumber, byte[] seed) throws IOException {
        // auth 323
        Auth323Packet auth323Packet = new Auth323Packet();
        if (password != null && password.length() > 0) {
            auth323Packet.seed = MySQLPasswordEncrypter.scramble323(password, new String(seed)).getBytes();
        }
        byte[] authPacketBodyBytes = auth323Packet.encode();
        HeaderPacket authPacketHeader = new HeaderPacket();
        authPacketHeader.setPacketBodyLength(authPacketBodyBytes.length);
        authPacketHeader.setPacketSequenceNumber((byte) (packetSequenceNumber + 1));
        log.info("client 323 authentication packet is sent out.");
        //
        PacketManager.writePackets(channel, authPacketHeader.encode(), authPacketBodyBytes);
        log.info("client 323 authentication packet is sent out.");
        // check auth result
        HeaderPacket headerPacket = PacketManager.readHeader(channel, 4);
        byte[] packetBodyBytes = PacketManager.readBytes(channel, headerPacket.getPacketBodyLength());
        assert packetBodyBytes != null;
        switch (packetBodyBytes[0]) {
            case 0:
                break;
            case -1:
                ErrorPacket errorPacket = new ErrorPacket();
                errorPacket.decode(packetBodyBytes);
                throw new IOException("Error When doing Client Authentication:" + errorPacket.toString());
            default:
                throw new IOException("unpexpected packet with field_count=" + packetBodyBytes[0]);
        }
    }

    private byte[] joinAndCreateScrumbleBuff(HandshakeInitializationPacket handshakeInitializationPacket) throws IOException {
        byte[] bytes = new byte[handshakeInitializationPacket.seed.length + handshakeInitializationPacket.restOfScrambleBuff.length];
        System.arraycopy(handshakeInitializationPacket.seed, 0, bytes, 0, handshakeInitializationPacket.seed.length);
        System.arraycopy(handshakeInitializationPacket.restOfScrambleBuff, 0, bytes, handshakeInitializationPacket.seed.length, handshakeInitializationPacket.restOfScrambleBuff.length);
        return bytes;
    }

}
